import matplotlib
matplotlib.use('Agg')

import utils_for_train_and_eval

from find_coords_utils import *
import find_coords_utils
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import (VGG16, VGG19, InceptionV3, ResNet50,
                                           Xception)
from PIL import Image
from tensorflow.keras import layers, models
from tensorflow import keras
from tqdm import tqdm
import tensorflow as tf
import scipy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import csv
import cv2
import time
import sys
import os
import math
import argparse


# Command-line options, parsed once at import time into the module-level
# `args` dict that create_model() and main() read.
ap = argparse.ArgumentParser()
# Directory (joined onto the cwd in main()) holding the estimator checkpoints.
ap.add_argument('-save_dir', '--save_dir', type=str,
                default='keras_model_simple')
# RMSprop `decay` and learning rate used when compiling the Keras models.
ap.add_argument('-decay', '--decay', type=float, default=0.0)
ap.add_argument('-lr', '--lr', type=float, default=2e-5)
# NOTE(review): proportion_kept, num_epochs and batch_size are not read by the
# code in this script; presumably kept for parity with the training script.
ap.add_argument('-proportion_kept', '--proportion_kept',
                type=float, default=1.0)
ap.add_argument('-num_epochs', '--num_epochs', type=int, default=8)
ap.add_argument('-batch_size', '--batch_size', type=int, default=24)
# Rate of the Dropout layer before the output; that layer is only added to
# the classifier when --mode is 'train'.
ap.add_argument('-dropout_rate', '--dropout_rate', type=float, default=0.5)
# Checkpoint step N to load as <save_dir>/model.ckpt-N; 0 passes an empty
# checkpoint path to the estimators instead.
ap.add_argument('-checkpoint', '--checkpoint', type=int, default=0)
ap.add_argument('-mode', '--mode', type=str, default='train')
# CSV of tiles; only rows whose 'prediction' column is 'yeskiln' are processed.
ap.add_argument('-eval_csv', '--eval_csv', type=str,
                default='../data/test_6_2019.csv')
# Also save, per connected component, an image with the found points overlaid.
ap.add_argument('-save_cc_img', '--save_cc_img', action='store_true')
# Backbone key for the CAM model (looked up in MODELS). main() reads
# args['model'], so it must be registered here or every run raises KeyError.
# 'vgg16_cam' matches the VGG16 backbone that create_model() builds.
ap.add_argument('-model', '--model', type=str, default='vgg16_cam')
args = vars(ap.parse_args())


# Integer class labels for the binary kiln / no-kiln classifier.
NOKILN_LABEL, YESKILN_LABEL = 0, 1

# Square side length, in pixels, of the network input.
IMAGE_DIM = INPUT_DIM = 224

# (height, width, channels): both the resize target for decoded tiles and the
# input_shape handed to the ImageNet backbone.
IMG_SIZE = (IMAGE_DIM, IMAGE_DIM, 3)


def create_model():
    """Build and compile the binary kiln classifier.

    Architecture: a frozen ImageNet VGG16 trunk cut one layer before its end,
    then a 14x14 average pool, flatten, a 512-unit ReLU layer, an optional
    Dropout (only when args['mode'] == 'train'), and a single sigmoid unit.

    Compiled with binary cross-entropy, RMSprop (args['lr'], args['decay'])
    and an accuracy metric.

    Returns:
        The compiled tf.keras Sequential model.
    """
    classifier = models.Sequential()

    # Frozen backbone: drop VGG16's final layer so the head sees the feature
    # map of layers[-2] (presumably 14x14 for a 224x224 input, matching the
    # pool size below).
    vgg = VGG16(weights='imagenet', include_top=False, input_shape=IMG_SIZE)
    trunk = models.Model(vgg.input, vgg.layers[-2].output)
    trunk.trainable = False
    classifier.add(trunk)

    # Classification head.
    classifier.add(layers.AveragePooling2D(pool_size=(14, 14)))
    classifier.add(layers.Flatten())
    classifier.add(layers.Dense(512, activation='relu'))
    training = args['mode'] == 'train'
    if training:
        classifier.add(layers.Dropout(args['dropout_rate']))
    classifier.add(layers.Dense(1, activation='sigmoid'))

    rmsprop = tf.keras.optimizers.RMSprop(lr=args['lr'], decay=args['decay'])
    classifier.compile(loss='binary_crossentropy',
                       optimizer=rmsprop,
                       metrics=['acc'])
    return classifier


def imgs_input_fn(filenames, labels=None, perform_shuffle=False, repeat_count=1, batch_size=1, input_name='', training=True):
    """Build a TF1 input pipeline yielding decoded, preprocessed image batches.

    Args:
        filenames: list of image file paths.
        labels: per-file labels, either flat or already shaped (N, 1).
            Defaults to all zeros when omitted.
        perform_shuffle: if True, shuffle the (filename, label) pairs before
            decoding.
        repeat_count: number of passes over the data (None repeats forever).
        batch_size: number of images per batch.
        input_name: key under which the image tensor is placed in the features
            dict (main() passes the Keras model's input name).
        training: accepted for API compatibility; currently unused.

    Returns:
        (features, labels) tensors from a one-shot iterator. features is
        ``{input_name: images}``; labels is float32 with a trailing axis of 1.
    """
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image = tf.image.decode_image(image_string, channels=3)
        # Give the decoded image a static rank of 3 so it can be resized.
        image.set_shape([None, None, None])
        image = tf.image.resize_images(image, IMG_SIZE[:2])
        # NOTE(review): subtracts one scalar from every channel rather than
        # per-channel means (116.779 is presumably the ImageNet green-channel
        # mean). Left unchanged because the classifier head was presumably
        # trained with this exact preprocessing -- confirm before altering.
        image = tf.subtract(image, 116.779)
        image.set_shape(IMG_SIZE)
        image = tf.reverse(image, axis=[2])  # 'RGB'->'BGR'
        return {input_name: image}, label

    if labels is None:
        labels = [0] * len(filenames)
    labels = np.array(labels)

    # Labels must be (N, 1) to line up with the model's single sigmoid output.
    if labels.ndim == 1:
        labels = np.expand_dims(labels, axis=1)

    # Build the dataset for both flat and already-(N, 1) labels.
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.constant(filenames), tf.cast(tf.constant(labels), tf.float32)))
    if perform_shuffle:
        # Shuffle paths before decoding so the buffer holds strings, not images.
        dataset = dataset.shuffle(buffer_size=max(len(filenames), 1))
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat(repeat_count)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels


def _build_cam_model(model_type):
    """Build a compiled model exposing a late convolutional feature map.

    The backbone class is looked up in MODELS by `model_type`, loaded with
    ImageNet weights without its top, and frozen. For 'vgg16_cam' the output
    is taken from ``layers[-2]`` (the same cut create_model() uses); other
    backbones are cut at ``layers[-4]``.
    """
    model_cam = models.Sequential()
    network = MODELS[model_type]
    backbone = network(weights='imagenet',
                       include_top=False, input_shape=IMG_SIZE)
    cut_index = -2 if model_type == 'vgg16_cam' else -4
    backbone = models.Model(backbone.input, backbone.layers[cut_index].output)
    backbone.trainable = False
    model_cam.add(backbone)
    model_cam.compile(loss='binary_crossentropy',
                      optimizer=tf.keras.optimizers.RMSprop(
                          lr=args['lr'], decay=args['decay']),
                      metrics=['acc'])
    return model_cam


def _collect_predictions(estimator, files, labels, input_name, checkpoint_path):
    """Run `estimator.predict` over `files` (batch size 1, no shuffle); return the outputs as a list."""
    def input_fn():
        return imgs_input_fn(files,
                             labels=labels,
                             perform_shuffle=False,
                             batch_size=1,
                             input_name=input_name,
                             training=False)

    return list(tqdm(estimator.predict(input_fn=input_fn,
                                       checkpoint_path=checkpoint_path)))


def _build_tile_graph(yeskiln_tiles):
    """Return a Graph with one Vertex per yeskiln tile, linked to its neighbours."""
    graph = Graph(yeskiln_tiles.shape[0], yeskiln_tiles)
    for _, tile in yeskiln_tiles.iterrows():
        vertex = Vertex(tile['row'], tile['col'], tile['new_id'])
        neighbors = graph.get_neighbors(vertex)
        for neighbor in neighbors:
            graph.add_edge(vertex, neighbor)
        if not neighbors:
            # Presumably registers isolated tiles so they still form their
            # own single-tile component -- confirm in find_coords_utils.
            graph.add_edge(vertex, None)
    return graph


def _save_points_overlay(fig_num, points, base_path):
    """Plot `points` in red over the post-processed mask; save as `<base_path>_points.jpeg`.

    Each point is plotted with point[1] on the x axis and point[0] on the
    y axis.
    """
    plt.figure(fig_num)
    ys = [p[0] for p in points]
    xs = [p[1] for p in points]
    plt.scatter(x=xs, y=ys, c='r', s=20)
    img = plt.imread(base_path + '_mask_postprocessed.jpeg')
    plt.imshow(img, zorder=0)
    # No plt.show(): the backend is Agg, and show() before savefig() can leave
    # an empty figure on interactive backends.
    plt.savefig(base_path + '_points.jpeg')
    plt.close(fig_num)


def _write_coords_csv(path, new_ids, coordinates, centroids, source_cc_path,
                      recentered_file_paths, labels):
    """Write one CSV row per found coordinate, with its id, centroid, source image and label."""
    # The csv module needs a binary handle on Python 2 and a text handle with
    # newline='' on Python 3.
    if sys.version_info[0] >= 3:
        f = open(path, 'w', newline='')
    else:
        f = open(path, 'wb')
    with f:
        wr = csv.writer(f, delimiter=',')
        wr.writerow(['new_id', 'lat', 'long', 'x', 'y', 'source_cc', 'path', 'label'])
        # NOTE(review): the points overlay treats centroid[0] as y and
        # centroid[1] as x, but these columns label them 'x' and 'y'. Confirm
        # which is right before relying on the x/y columns.
        for i, coords in enumerate(coordinates):
            wr.writerow([new_ids[i], coords[0], coords[1],
                         centroids[i][0], centroids[i][1],
                         source_cc_path[i], recentered_file_paths[i], labels[i]])


def main():
    """Locate kiln coordinates within the tiles predicted as 'yeskiln'.

    Steps:
      1. Read args['eval_csv'] and keep the rows whose 'prediction' is
         'yeskiln'.
      2. Run the kiln classifier and a truncated-backbone CAM model over those
         tiles, then pass both outputs to find_coords_utils.mask_images.
      3. Group neighbouring yeskiln tiles into connected components. For each
         component, stitch one image plus mask (saved under test_imgs/) and
         extract coordinates from it.
      4. Write '<eval_csv without extension>_coords.csv' and the mean of the
         collected cc_sizes to 'avg_kiln_size.txt'.
    """
    # Prepare the data: only tiles already predicted as kilns are processed.
    all_tiles = pd.read_csv(args['eval_csv'])
    yeskiln_tiles = all_tiles.loc[all_tiles['prediction'] == 'yeskiln']
    test_files = yeskiln_tiles['path'].tolist()
    test_labels = [YESKILN_LABEL] * len(test_files)

    # Fall back to the VGG16 CAM variant if the parser has no --model option.
    model_type = args.get('model', 'vgg16_cam')

    model_dir = os.path.join(os.getcwd(), args['save_dir'])
    model = create_model()
    model.summary()  # summary() prints itself and returns None

    model_cam = _build_cam_model(model_type)
    model_cam.summary()

    # Wrap both models as estimators pointing at the same model_dir.
    input_name = model.input_names[0]
    cam_input_name = model_cam.input_names[0]
    kiln_model = tf.keras.estimator.model_to_estimator(
        keras_model=model, model_dir=model_dir)
    kiln_model_cam = tf.keras.estimator.model_to_estimator(
        keras_model=model_cam, model_dir=model_dir)

    # '' (the default) passes no explicit checkpoint; --checkpoint N pins model.ckpt-N.
    checkpoint_path = ''
    if args['checkpoint'] != 0:
        checkpoint_path = os.path.join(
            args['save_dir'], 'model.ckpt-' + str(args['checkpoint']))

    print('Running inference on the images to get final output')
    results = _collect_predictions(kiln_model, test_files, test_labels,
                                   input_name, checkpoint_path)

    print('Running inference on the images to get last conv layer output')
    results_cam = _collect_predictions(kiln_model_cam, test_files, test_labels,
                                       cam_input_name, checkpoint_path)

    img_path_to_mask = find_coords_utils.mask_images(
        results, results_cam, test_files, kiln_model, model_type=model_type)

    graph = _build_tile_graph(yeskiln_tiles)

    # Per-component images are saved under test_imgs/; make sure it exists.
    if not os.path.isdir('test_imgs'):
        os.makedirs('test_imgs')

    print('Finding and going over the connected components')
    connected_comps = graph.find_connected_components()
    new_ids = []
    labels = []
    recentered_file_paths = []
    centroids = []
    coordinates = []
    source_cc_path = []  # component image base path, one entry per coordinate
    cc_sizes = []
    for i, cc in enumerate(tqdm(connected_comps)):
        image_of_full_area, image_of_full_area_mask, lat, lon = \
            find_coords_utils.create_connected_component_image(
                cc, yeskiln_tiles, all_tiles, img_path_to_mask)

        # Save the stitched image and raw mask, named by the component's
        # row/col extent.
        rows = [v.row for v in cc]
        cols = [v.col for v in cc]
        image_of_full_area_path = 'test_imgs/' + \
            str(min(rows)) + ',' + str(min(cols)) + '_to_' + \
            str(max(rows)) + ',' + str(max(cols))
        image_of_full_area.save(image_of_full_area_path + '.jpeg')
        image_of_full_area_mask.save(image_of_full_area_path + '_mask.jpeg')

        image_of_full_area_mask = find_coords_utils.postprocess_connected_component_mask(
            image_of_full_area_mask)
        image_of_full_area_mask.save(image_of_full_area_path + '_mask_postprocessed.jpeg')

        # No separate shape model is used in this script.
        shape_model = None
        shape_input_name = None
        shape_checkpoint_path = ''

        full_area_mask_arr = np.array(image_of_full_area_mask, dtype=np.uint8)
        # cc_sizes is passed in and replaced by the returned list; presumably
        # get_coordinates appends this component's sizes -- confirm.
        _centroids, coords, cc_sizes, recentered_paths, _new_ids, _labels = \
            find_coords_utils.get_coordinates(
                image_of_full_area, full_area_mask_arr, lat, lon, shape_model,
                shape_input_name, shape_checkpoint_path, cc_sizes=cc_sizes)
        new_ids.extend(_new_ids)
        labels.extend(_labels)
        recentered_file_paths.extend(recentered_paths)
        centroids.extend(_centroids)
        coordinates.extend(coords)
        source_cc_path.extend([image_of_full_area_path] * len(coords))

        if args['save_cc_img']:
            # Overlay only this component's centroids. The accumulated list
            # mixes in points from other components' images.
            _save_points_overlay(i, _centroids, image_of_full_area_path)

    coords_root, _ = os.path.splitext(args['eval_csv'])
    _write_coords_csv(coords_root + '_coords.csv', new_ids, coordinates,
                      centroids, source_cc_path, recentered_file_paths, labels)

    with open('avg_kiln_size.txt', 'w') as f:
        mean_cc_size = np.mean(np.array(cc_sizes))
        f.write(str(mean_cc_size))


# Run the coordinate-finding pipeline only when executed as a script.
# (The module-level argparse call still runs whenever the module is imported.)
if __name__ == '__main__':
    main()
